{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "2f3f6cc1-e267-4774-8e1e-1da312528f1c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Copyright (c) OpenMMLab. All rights reserved.\n",
    "import argparse\n",
    "import os.path as osp\n",
    "from collections import OrderedDict\n",
    "from typing import Union\n",
    "\n",
    "import mmengine\n",
    "import torch\n",
    "from mmengine.runner.checkpoint import _load_checkpoint\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "db29e185-0931-47cd-8d2d-32077b1c6cb9",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "\n",
    "def convert_mmcls_to_timm(state_dict: Union[OrderedDict, dict]) -> OrderedDict:\n",
    "    \"\"\"Convert keys in MMClassification pretrained vit models to timm style.\n",
    "\n",
    "    Only entries whose key starts with ``backbone.`` are kept. That leading\n",
    "    prefix is stripped and the rest of the key is rewritten with the\n",
    "    following substring replacements, applied in this order:\n",
    "\n",
    "    * ``projection``     -> ``proj``\n",
    "    * ``ffn.layers.0.0`` -> ``mlp.fc1``\n",
    "    * ``ffn.layers.1``   -> ``mlp.fc2``\n",
    "    * ``layers``         -> ``blocks``\n",
    "    * ``ln``             -> ``norm``\n",
    "\n",
    "    Afterwards the top-level ``norm1.weight`` / ``norm1.bias`` entries, when\n",
    "    present, are renamed to ``norm.weight`` / ``norm.bias``.\n",
    "\n",
    "    Args:\n",
    "        state_dict (Union[OrderedDict, dict]): The state dict of\n",
    "            MMClassification pretrained vit models.\n",
    "\n",
    "    Returns:\n",
    "        OrderedDict: A single-entry mapping ``{'model': converted}`` where\n",
    "        ``converted`` is the dict of renamed backbone weights.\n",
    "    \"\"\"\n",
    "    prefix = 'backbone.'\n",
    "    # Order matters: the specific ffn.layers.* rules must run before the\n",
    "    # generic layers -> blocks rule, otherwise they could never match.\n",
    "    renames = (\n",
    "        ('projection', 'proj'),\n",
    "        ('ffn.layers.0.0', 'mlp.fc1'),\n",
    "        ('ffn.layers.1', 'mlp.fc2'),\n",
    "        ('layers', 'blocks'),\n",
    "        ('ln', 'norm'),\n",
    "    )\n",
    "\n",
    "    converted = {}\n",
    "    for key, value in state_dict.items():\n",
    "        # only keep the backbone weights\n",
    "        if not key.startswith(prefix):\n",
    "            continue\n",
    "        # Strip only the leading prefix; str.replace would also delete any\n",
    "        # later occurrence of 'backbone.' inside the key.\n",
    "        new_key = key[len(prefix):]\n",
    "        for old, new in renames:\n",
    "            new_key = new_key.replace(old, new)\n",
    "        converted[new_key] = value\n",
    "\n",
    "    # replace the last norm1 with norm; skipped when the checkpoint has no\n",
    "    # top-level norm1 entry instead of raising KeyError\n",
    "    for suffix in ('weight', 'bias'):\n",
    "        final_key = f'norm1.{suffix}'\n",
    "        if final_key in converted:\n",
    "            converted[f'norm.{suffix}'] = converted.pop(final_key)\n",
    "\n",
    "    return OrderedDict({'model': converted})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6db6900f-8087-459a-89de-2bb97464c7cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert an MMClassification-style checkpoint to the timm key layout and save it.\n",
    "src = '/dssg/home/acct-medftn/medftn/BEPT/Model/mmselfsup/TCGA_Checkpoints/beitv2_backbone_cancer.pth'\n",
    "# NOTE(review): dst is identical to src, so saving replaces the original\n",
    "# checkpoint in place and a second run of this cell would load the already\n",
    "# converted file -- confirm this is intended.\n",
    "dst = '/dssg/home/acct-medftn/medftn/BEPT/Model/mmselfsup/TCGA_Checkpoints/beitv2_backbone_cancer.pth'\n",
    "\n",
    "checkpoint = _load_checkpoint(src, map_location='cpu')\n",
    "# Use the nested weights when the checkpoint stores them under 'state_dict';\n",
    "# otherwise treat the loaded object itself as the state dict.\n",
    "state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint\n",
    "\n",
    "converted = convert_mmcls_to_timm(state_dict)\n",
    "\n",
    "# The original cell referenced `args.dst`, but no `args` object exists in this\n",
    "# notebook (NameError on a fresh kernel); use the `dst` path defined above.\n",
    "mmengine.mkdir_or_exist(osp.dirname(dst))\n",
    "torch.save(converted, dst)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mmselfsup_yzc",
   "language": "python",
   "name": "mmselfsup_yzc"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
